{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using MXNet DALI plugin: using various readers\n",
    "\n",
    "### Overview\n",
    "\n",
    "This example shows how different readers could be used to interact with MXNet. It shows how flexible DALI is.\n",
    "\n",
    "The following readers are used in this example:\n",
    "\n",
    "- MXNetReader\n",
    "- CaffeReader\n",
    "- FileReader\n",
    "- TFRecordReader\n",
    "\n",
    "For details on how to use them please see other [examples](../../index.rst)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us start from defining some global constants\n",
    "\n",
    "`DALI_EXTRA_PATH` environment variable should point to the place where data from [DALI extra repository](https://github.com/NVIDIA/DALI_extra) is downloaded. Please make sure that the proper release tag is checked out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "test_data_root = os.environ['DALI_EXTRA_PATH']\n",
    "\n",
    "# MXNet RecordIO\n",
    "db_folder = os.path.join(test_data_root, 'db', 'recordio/')\n",
    "\n",
    "# Caffe LMDB\n",
    "lmdb_folder = os.path.join(test_data_root, 'db', 'lmdb')\n",
    "\n",
    "# image dir with plain jpeg files\n",
    "image_dir = \"../../data/images\"\n",
    "\n",
    "# TFRecord\n",
    "tfrecord = os.path.join(test_data_root, 'db', 'tfrecord', 'train')\n",
    "tfrecord_idx = \"idx_files/train.idx\"\n",
    "tfrecord2idx_script = \"tfrecord2idx\"\n",
    "\n",
    "N = 8             # number of GPUs\n",
    "BATCH_SIZE = 128  # batch size per GPU\n",
    "ITERATIONS = 32\n",
    "IMAGE_SIZE = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create idx file by calling `tfrecord2idx` script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from subprocess import call\n",
    "import os.path\n",
    "\n",
    "if not os.path.exists(\"idx_files\"):\n",
    "    os.mkdir(\"idx_files\")\n",
    "\n",
    "if not os.path.isfile(tfrecord_idx):\n",
    "    call([tfrecord2idx_script, tfrecord, tfrecord_idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us define:\n",
    "- common part of pipeline, other pipelines will inherit it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nvidia.dali.pipeline import Pipeline\n",
    "import nvidia.dali.ops as ops\n",
    "import nvidia.dali.types as types\n",
    "\n",
    "class CommonPipeline(Pipeline):\n",
    "    def __init__(self, batch_size, num_threads, device_id):\n",
    "        super(CommonPipeline, self).__init__(batch_size, num_threads, device_id)\n",
    "\n",
    "        self.decode = ops.ImageDecoder(device = \"mixed\", output_type = types.RGB)\n",
    "        self.resize = ops.Resize(device = \"gpu\",\n",
    "                                 image_type = types.RGB,\n",
    "                                 interp_type = types.INTERP_LINEAR)\n",
    "        self.cmn = ops.CropMirrorNormalize(device = \"gpu\",\n",
    "                                            output_dtype = types.FLOAT,\n",
    "                                            crop = (227, 227),\n",
    "                                            image_type = types.RGB,\n",
    "                                            mean = [128., 128., 128.],\n",
    "                                            std = [1., 1., 1.])\n",
    "        self.uniform = ops.Uniform(range = (0.0, 1.0))\n",
    "        self.resize_rng = ops.Uniform(range = (256, 480))\n",
    "\n",
    "    def base_define_graph(self, inputs, labels):\n",
    "        images = self.decode(inputs)\n",
    "        images = self.resize(images, resize_shorter = self.resize_rng())\n",
    "        output = self.cmn(images, crop_pos_x = self.uniform(),\n",
    "                          crop_pos_y = self.uniform())\n",
    "        return (output, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- MXNetReaderPipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nvidia.dali.pipeline import Pipeline\n",
    "import nvidia.dali.ops as ops\n",
    "import nvidia.dali.types as types\n",
    "\n",
    "class MXNetReaderPipeline(CommonPipeline):\n",
    "    def __init__(self, batch_size, num_threads, device_id, num_gpus):\n",
    "        super(MXNetReaderPipeline, self).__init__(batch_size, num_threads, device_id)\n",
    "        self.input = ops.MXNetReader(path = [db_folder+\"train.rec\"], index_path=[db_folder+\"train.idx\"],\n",
    "                                     random_shuffle = True, shard_id = device_id, num_shards = num_gpus)\n",
    "\n",
    "    def define_graph(self):\n",
    "        images, labels = self.input(name=\"Reader\")\n",
    "        return self.base_define_graph(images, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- CaffeReadPipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CaffeReadPipeline(CommonPipeline):\n",
    "    def __init__(self, batch_size, num_threads, device_id, num_gpus):\n",
    "        super(CaffeReadPipeline, self).__init__(batch_size, num_threads, device_id)\n",
    "        self.input = ops.CaffeReader(path = lmdb_folder,\n",
    "                                     random_shuffle = True, shard_id = device_id, num_shards = num_gpus)\n",
    "\n",
    "    def define_graph(self):\n",
    "        images, labels = self.input(name=\"Reader\")\n",
    "        return self.base_define_graph(images, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- FileReadPipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FileReadPipeline(CommonPipeline):\n",
    "        def __init__(self, batch_size, num_threads, device_id, num_gpus):\n",
    "            super(FileReadPipeline, self).__init__(batch_size, num_threads, device_id)\n",
    "            self.input = ops.FileReader(file_root = image_dir)\n",
    "\n",
    "        def define_graph(self):\n",
    "            images, labels = self.input(name=\"Reader\")\n",
    "            return self.base_define_graph(images, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- TFRecordPipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nvidia.dali.tfrecord as tfrec\n",
    "\n",
    "class TFRecordPipeline(CommonPipeline):\n",
    "    def __init__(self, batch_size, num_threads, device_id, num_gpus):\n",
    "        super(TFRecordPipeline, self).__init__(batch_size, num_threads, device_id)\n",
    "        self.input = ops.TFRecordReader(path = tfrecord, \n",
    "                                        index_path = tfrecord_idx,\n",
    "                                        features = {\"image/encoded\" : tfrec.FixedLenFeature((), tfrec.string, \"\"),\n",
    "                                                    \"image/class/label\": tfrec.FixedLenFeature([1], tfrec.int64,  -1)\n",
    "                                        })\n",
    "\n",
    "    def define_graph(self):\n",
    "        inputs = self.input(name=\"Reader\")\n",
    "        images = inputs[\"image/encoded\"]\n",
    "        labels = inputs[\"image/class/label\"]\n",
    "        return self.base_define_graph(images, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us create pipelines and pass them to MXNet generic iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RUN: MXNetReaderPipeline\n",
      "OK : MXNetReaderPipeline\n",
      "RUN: CaffeReadPipeline\n",
      "OK : CaffeReadPipeline\n",
      "RUN: FileReadPipeline\n",
      "OK : FileReadPipeline\n",
      "RUN: TFRecordPipeline\n",
      "OK : TFRecordPipeline\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "import numpy as np\n",
    "from nvidia.dali.plugin.mxnet import DALIGenericIterator\n",
    "\n",
    "pipe_types = [[MXNetReaderPipeline, (0, 999)], \n",
    "              [CaffeReadPipeline, (0, 999)],\n",
    "              [FileReadPipeline, (0, 1)], \n",
    "              [TFRecordPipeline, (1, 1000)]]\n",
    "for pipe_t in pipe_types:\n",
    "    pipe_name, label_range = pipe_t\n",
    "    print (\"RUN: \"  + pipe_name.__name__)\n",
    "    pipes = [pipe_name(batch_size=BATCH_SIZE, num_threads=2, device_id = device_id, num_gpus = N) for device_id in range(N)]\n",
    "    pipes[0].build()\n",
    "    dali_iter = DALIGenericIterator(pipes, [('data', DALIGenericIterator.DATA_TAG),\n",
    "                                            ('label', DALIGenericIterator.LABEL_TAG)],\n",
    "                                    pipes[0].epoch_size(\"Reader\"))\n",
    "\n",
    "    for i, data in enumerate(dali_iter):\n",
    "        if i >= ITERATIONS:\n",
    "            break\n",
    "        # Testing correctness of labels\n",
    "        for d in data:\n",
    "            label = d.label[0].asnumpy()\n",
    "            image = d.data[0]\n",
    "            ## labels need to be integers\n",
    "            assert(np.equal(np.mod(label, 1), 0).all())\n",
    "            ## labels need to be in range pipe_name[2]\n",
    "            assert((label >= label_range[0]).all())\n",
    "            assert((label <= label_range[1]).all())\n",
    "    print(\"OK : \" + pipe_name.__name__)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
